import unittest
import numpy as np
from xinfer_yolo.simd.avx_cpy import avx_mul


class TestAVX(unittest.TestCase):
    """Check ``avx_mul`` against NumPy's element-wise float32 multiply."""

    # Fixed seed so that a failing run can be reproduced exactly.
    SEED = 0
    # Number of elements the loop advances between ``avx_mul`` calls.
    # NOTE(review): presumes avx_mul consumes 4 float32 lanes per call --
    # confirm against the kernel; a wider (e.g. 8-lane) kernel would read
    # past the end of the arrays on the final call.
    STRIDE = 4

    def test_basic(self):
        """Multiply two random float32 vectors chunk by chunk and compare to NumPy."""
        # Kept a multiple of STRIDE so every call starts on a chunk boundary.
        n = 400

        rng = np.random.default_rng(self.SEED)
        a = rng.random(n).astype(np.float32)
        b = rng.random(n).astype(np.float32)
        # NaN sentinel: any element avx_mul never writes cannot match the
        # expected (finite) product, so skipped lanes make the test fail.
        c = np.full(n, np.nan, dtype=np.float32)

        for i in range(0, n, self.STRIDE):
            # Basic slices are views, so results avx_mul presumably writes
            # into its third argument land directly in ``c``.
            avx_mul(a[i:], b[i:], c[i:])

        expected = a * b

        # Exact comparison (same semantics as the previous list equality):
        # assumes avx_mul performs a plain correctly-rounded float32 multiply,
        # which should agree bit-for-bit with NumPy's float32 result.
        np.testing.assert_array_equal(c, expected)